import os
import json
import torch
import numpy as np
from scipy.stats import pearsonr

# ---------- 0. Paths ----------
subj = 2
# Zero-pad the subject id to two digits (subj02, subj10, ...); a literal "0" prefix
# would produce "subj010" for two-digit ids.
output_dir = f"./subj{subj:02d}_prior_vae_eye"
predicted_path = os.path.join(
    output_dir,
    "predicted_fmri_ep150_step100_repeat5.pt",  # must match the filename saved by the inference run
)
# ---------- 1. Load data ----------
# map_location="cpu" makes tensors that were saved from a GPU loadable on a CPU-only
# machine (a plain torch.load followed by .cpu() would fail before .cpu() runs).
predicted_fmri = torch.load(predicted_path, map_location="cpu")  # assumed (N, num_voxels) — TODO confirm
# NOTE: test_voxel_mean must be the same ground truth (same samples, same order) that was
# used when predicted_fmri was generated. If it was not saved separately, export it from
# the original inference script; this script assumes it was stored as test_voxel_mean.pt.
test_voxel_mean_path = os.path.join(output_dir, "test_voxel_mean.pt")
test_voxel_mean = torch.load(test_voxel_mean_path, map_location="cpu")  # assumed (N, num_voxels)

# The metric loop pairs rows with zip(), which silently truncates to the shorter input;
# fail loudly instead of evaluating mismatched or partial data.
if predicted_fmri.shape != test_voxel_mean.shape:
    raise ValueError(
        f"Shape mismatch: predicted_fmri {tuple(predicted_fmri.shape)} "
        f"vs test_voxel_mean {tuple(test_voxel_mean.shape)}"
    )
if predicted_fmri.numel() == 0:
    raise ValueError("No samples to evaluate: predicted_fmri is empty")

# ---------- 2. Compute metrics ----------
# Per-sample metrics: each row of predicted_fmri is compared with the matching row of
# test_voxel_mean across voxels; the lists are averaged over samples afterwards.
mse_list, pearson_list, r2_list = [], [], []
n_constant = 0  # samples whose Pearson r is undefined (constant prediction or target)

for pred, true in zip(predicted_fmri, test_voxel_mean):
    # detach(): .numpy() refuses tensors that require grad.
    # double(): .numpy() cannot convert bfloat16, and float64 keeps the voxel-wise sums accurate.
    pred_np = pred.detach().double().numpy()
    true_np = true.detach().double().numpy()
    residual = true_np - pred_np

    # MSE
    mse_list.append(float(np.mean(residual ** 2)))

    # Pearson r is undefined when either vector is constant (scipy returns NaN and warns);
    # a single NaN would turn the averaged metric into NaN, so skip those samples.
    if np.ptp(pred_np) == 0 or np.ptp(true_np) == 0:
        n_constant += 1
    else:
        r, _ = pearsonr(pred_np, true_np)
        pearson_list.append(float(r))

    # R² = 1 - SS_res / SS_tot; the epsilon avoids division by zero for a constant target.
    ss_res = np.sum(residual ** 2)
    ss_tot = np.sum((true_np - np.mean(true_np)) ** 2)
    r2_list.append(float(1 - ss_res / (ss_tot + 1e-8)))

if n_constant:
    print(f"Warning: Pearson r skipped for {n_constant} sample(s) with constant input")

# ---------- 3. Print and save ----------
def _mean_or_none(values):
    """Return the mean of *values* as a float, or None if it is empty or non-finite.

    json would otherwise emit bare NaN/Infinity tokens, which are not valid JSON and are
    rejected by strict parsers; None is serialized as null.
    """
    if len(values) == 0:
        return None
    mean = float(np.mean(values))
    return mean if np.isfinite(mean) else None


metrics = {
    "mse": _mean_or_none(mse_list),
    "pearson_r": _mean_or_none(pearson_list),
    "r2": _mean_or_none(r2_list),
}

print(json.dumps(metrics, indent=4))

# allow_nan=False guarantees the file is strict JSON (all values are finite floats or None).
with open(os.path.join(output_dir, "final_metrics.json"), "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=4, allow_nan=False)